Lab3 Assignment¶

Univariate Linear Regression on the Scikit-Learn Diabetes Dataset¶

Library Imports¶
In [15]:
# Import necessary libraries
from sklearn.model_selection import train_test_split 
from sklearn.linear_model import LinearRegression
from sklearn import metrics
import matplotlib.pyplot as plt
import seaborn as sns
import plotly
import plotly.graph_objects as go
plotly.offline.init_notebook_mode()
Importing Datasets¶
In [16]:
from sklearn import datasets
diabetes_X, diabetes_y = datasets.load_diabetes(return_X_y=True)
print(diabetes_X)
print(diabetes_y)
[[ 0.03807591  0.05068012  0.06169621 ... -0.00259226  0.01990749
  -0.01764613]
 [-0.00188202 -0.04464164 -0.05147406 ... -0.03949338 -0.06833155
  -0.09220405]
 [ 0.08529891  0.05068012  0.04445121 ... -0.00259226  0.00286131
  -0.02593034]
 ...
 [ 0.04170844  0.05068012 -0.01590626 ... -0.01107952 -0.04688253
   0.01549073]
 [-0.04547248 -0.04464164  0.03906215 ...  0.02655962  0.04452873
  -0.02593034]
 [-0.04547248 -0.04464164 -0.0730303  ... -0.03949338 -0.00422151
   0.00306441]]
[151.  75. 141. 206. 135.  97. 138.  63. 110. 310. 101.  69. 179. 185.
 118. 171. 166. 144.  97. 168.  68.  49.  68. 245. 184. 202. 137.  85.
 131. 283. 129.  59. 341.  87.  65. 102. 265. 276. 252.  90. 100.  55.
  61.  92. 259.  53. 190. 142.  75. 142. 155. 225.  59. 104. 182. 128.
  52.  37. 170. 170.  61. 144.  52. 128.  71. 163. 150.  97. 160. 178.
  48. 270. 202. 111.  85.  42. 170. 200. 252. 113. 143.  51.  52. 210.
  65. 141.  55. 134.  42. 111.  98. 164.  48.  96.  90. 162. 150. 279.
  92.  83. 128. 102. 302. 198.  95.  53. 134. 144. 232.  81. 104.  59.
 246. 297. 258. 229. 275. 281. 179. 200. 200. 173. 180.  84. 121. 161.
  99. 109. 115. 268. 274. 158. 107.  83. 103. 272.  85. 280. 336. 281.
 118. 317. 235.  60. 174. 259. 178. 128.  96. 126. 288.  88. 292.  71.
 197. 186.  25.  84.  96. 195.  53. 217. 172. 131. 214.  59.  70. 220.
 268. 152.  47.  74. 295. 101. 151. 127. 237. 225.  81. 151. 107.  64.
 138. 185. 265. 101. 137. 143. 141.  79. 292. 178.  91. 116.  86. 122.
  72. 129. 142.  90. 158.  39. 196. 222. 277.  99. 196. 202. 155.  77.
 191.  70.  73.  49.  65. 263. 248. 296. 214. 185.  78.  93. 252. 150.
  77. 208.  77. 108. 160.  53. 220. 154. 259.  90. 246. 124.  67.  72.
 257. 262. 275. 177.  71.  47. 187. 125.  78.  51. 258. 215. 303. 243.
  91. 150. 310. 153. 346.  63.  89.  50.  39. 103. 308. 116. 145.  74.
  45. 115. 264.  87. 202. 127. 182. 241.  66.  94. 283.  64. 102. 200.
 265.  94. 230. 181. 156. 233.  60. 219.  80.  68. 332. 248.  84. 200.
  55.  85.  89.  31. 129.  83. 275.  65. 198. 236. 253. 124.  44. 172.
 114. 142. 109. 180. 144. 163. 147.  97. 220. 190. 109. 191. 122. 230.
 242. 248. 249. 192. 131. 237.  78. 135. 244. 199. 270. 164.  72.  96.
 306.  91. 214.  95. 216. 263. 178. 113. 200. 139. 139.  88. 148.  88.
 243.  71.  77. 109. 272.  60.  54. 221.  90. 311. 281. 182. 321.  58.
 262. 206. 233. 242. 123. 167.  63. 197.  71. 168. 140. 217. 121. 235.
 245.  40.  52. 104. 132.  88.  69. 219.  72. 201. 110.  51. 277.  63.
 118.  69. 273. 258.  43. 198. 242. 232. 175.  93. 168. 275. 293. 281.
  72. 140. 189. 181. 209. 136. 261. 113. 131. 174. 257.  55.  84.  42.
 146. 212. 233.  91. 111. 152. 120.  67. 310.  94. 183.  66. 173.  72.
  49.  64.  48. 178. 104. 132. 220.  57.]
Traning and Testing Sets¶
In [17]:
# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(diabetes_X[:, 2], diabetes_y, test_size=0.2, random_state=42)
X_train1 = X_train.reshape(-1, 1)
X_test1 = X_test.reshape(-1, 1)
Model Training¶
In [18]:
# Train the model
model = LinearRegression()
model.fit(X_train.reshape(-1, 1), y_train)
Out[18]:
LinearRegression()
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.
LinearRegression()
Prediction¶
In [19]:
# Make predictions using the testing set
y_pred = model.predict(X_test1)
y_pred1 = model.predict(X_train1)
print(y_pred)
print(y_pred1)
[145.80622687 188.85739048 147.95878505 203.92529774 131.8145987
 127.50948234 322.31599764 197.4676232   61.85645785 167.33180868
 118.89924962  94.14483055  90.91599328 166.25552959  96.29738873
 157.64529687 223.29832136 240.5187868  180.24715776 210.38297228
 191.00994866 109.21273781 102.75506327 174.86576231 196.39134411
 166.25552959 211.45925137 133.96715688  78.0006442  130.73831961
 244.82390316 114.59413326 166.25552959 145.80622687 192.08622775
 229.7559959  121.0518078  118.89924962 121.0518078   94.14483055
  82.30576056 122.12808689 129.66204052 118.89924962 107.06017963
 116.74669144 115.67041235 101.67878418  67.2378533  153.34018051
 210.38297228  82.30576056 169.48436686 111.36529599 133.96715688
 216.84064682 105.98390054 213.61180955 133.96715688  97.37366782
 182.39971594 193.16250684 206.07785592 107.06017963  86.61087692
 170.56064595 140.42483142 127.50948234 117.82297053 139.34855233
 133.96715688 182.39971594 130.73831961 141.50111051  90.91599328
 112.44157508 212.53553046 191.00994866 171.63692504 131.8145987
 147.95878505 140.42483142  87.68715601  91.99227237  86.61087692
 126.43320325  88.7634351   82.30576056 187.78111139]
[164.10297141 133.96715688 201.77273956 116.74669144  86.61087692
 112.44157508 172.71320413 221.14576318 189.93366957 135.04343597
  83.38203965 171.63692504  94.14483055 256.66297315  72.61924875
 172.71320413 111.36529599 116.74669144 108.13645872 171.63692504
 178.09459958 115.67041235 194.23878593 101.67878418 207.15413501
 166.25552959 140.42483142 141.50111051 142.5773896  147.95878505
 167.33180868 175.9420414  114.59413326 119.97552871 170.56064595
 136.11971506 145.80622687 160.87413414 114.59413326 142.5773896
  74.77180693 150.11134323 107.06017963 197.4676232  184.55227412
 237.28994953 141.50111051 262.0443686  145.80622687 153.34018051
 266.34948496 111.36529599 169.48436686 211.45925137 105.98390054
 199.62018138 157.64529687 121.0518078  206.07785592 146.88250596
 142.5773896  177.01832049 280.34111313 166.25552959 151.18762232
 121.0518078  156.56901778 161.95041323 121.0518078  113.51785417
 154.4164596  126.43320325  85.53459783 112.44157508 104.90762145
 119.97552871 154.4164596  197.4676232   97.37366782 113.51785417
 136.11971506 140.42483142  62.93273694 126.43320325 174.86576231
 213.61180955 144.72994778 231.90855408 212.53553046 208.2304141
 126.43320325 156.56901778 158.72157596 124.28064507 104.90762145
 147.95878505 112.44157508  86.61087692 222.22204227 179.17087867
 175.9420414   68.31413239 223.29832136 193.16250684 129.66204052
 223.29832136  93.06855146 221.14576318 158.72157596 168.40808777
 206.07785592 110.2890169  192.08622775 220.06948409  93.06855146
 103.83134236 127.50948234 224.37460045  70.46669057 244.82390316
 196.39134411 156.56901778  96.29738873 105.98390054 122.12808689
 186.7048323  138.27227324 161.95041323 157.64529687 103.83134236
 117.82297053 126.43320325 240.5187868  158.72157596 122.12808689
 212.53553046 163.02669232 198.54390229 232.98483317 185.62855321
 153.34018051 157.64529687 137.19599415 191.00994866 256.66297315
 279.26483404 216.84064682 122.12808689 144.72994778 199.62018138
 143.65366869 211.45925137 118.89924962  89.83971419 128.58576143
 139.34855233  76.92436511 142.5773896  110.2890169  168.40808777
 240.5187868  154.4164596  207.15413501 104.90762145 116.74669144
 128.58576143 192.08622775 141.50111051 173.78948322 172.71320413
 136.11971506 160.87413414 138.27227324 228.67971681 111.36529599
 158.72157596 188.85739048 127.50948234 132.89087779  70.46669057
 155.49273869 127.50948234 152.26390142 153.34018051 156.56901778
  85.53459783 118.89924962 143.65366869 150.11134323 113.51785417
 201.77273956 122.12808689  98.44994691 213.61180955 220.06948409
  71.54296966 127.50948234  99.526226   223.29832136 128.58576143
 194.23878593 195.31506502 123.20436598 121.0518078  130.73831961
 131.8145987  164.10297141  80.15320238 185.62855321 113.51785417
  94.14483055 168.40808777 226.52715863  95.22110964 237.28994953
 143.65366869  84.45831874 115.67041235 178.09459958 146.88250596
 208.2304141  113.51785417 115.67041235 312.62948583 121.0518078
 169.48436686 105.98390054 161.95041323  81.22948147 105.98390054
 146.88250596 129.66204052 245.90018225 156.56901778  79.07692329
 277.11227586 151.18762232  94.14483055 170.56064595 211.45925137
 123.20436598 223.29832136 143.65366869 135.04343597 182.39971594
 149.03506414 263.12064769 144.72994778 208.2304141   89.83971419
 213.61180955 143.65366869 126.43320325 225.45087954 126.43320325
 141.50111051 143.65366869 274.95971768 201.77273956 187.78111139
 127.50948234 212.53553046 139.34855233 158.72157596 197.4676232
 155.49273869 191.00994866 249.12901952 250.20529861 121.0518078
 206.07785592 186.7048323  212.53553046 119.97552871  96.29738873
 215.76436773 128.58576143 183.47599503 139.34855233 153.34018051
 115.67041235 119.97552871 136.11971506 100.60250509 117.82297053
 166.25552959 164.10297141 198.54390229  88.7634351   74.77180693
 288.95134585 203.92529774 144.72994778 142.5773896  149.03506414
 180.24715776 143.65366869 150.11134323 131.8145987  111.36529599
 167.33180868 131.8145987   91.99227237 218.993205   153.34018051
 195.31506502 150.11134323 163.02669232 156.56901778 105.98390054
 177.01832049 127.50948234 144.72994778 161.95041323  85.53459783
 136.11971506  97.37366782 125.35692416  87.68715601 110.2890169
 140.42483142 137.19599415 211.45925137 129.66204052 205.00157683
  89.83971419 178.09459958 110.2890169  132.89087779 244.82390316
 151.18762232 170.56064595  96.29738873 105.98390054 146.88250596
 248.05274043 147.95878505 157.64529687  87.68715601 128.58576143
 145.80622687 182.39971594 118.89924962 169.48436686  79.07692329
  95.22110964 149.03506414 185.62855321  75.84808602 182.39971594
 131.8145987  128.58576143 180.24715776]
Mean Squared Error¶
In [20]:
# train_loss = model.score(y_train, y_pred)
train_loss = metrics.mean_squared_error(y_train, y_pred1)
test_loss  = metrics.mean_squared_error(y_test,y_pred)
# test_loss = model.score(y_test, y_pred)

print("Train Loss:", train_loss)
print("Test Loss:", test_loss)
Train Loss: 3854.11265207582
Test Loss: 4061.8259284949268
Plotting Grpah using Matplotlib¶
In [21]:
# Plotting the training data, test data, and the model
plt.scatter(X_train, y_train, color ='green', label='Training Data')
plt.scatter(X_test, y_test, color ='blue', label='Test Data')
plt.plot(X_train, y_pred1, color='red', label='Linear Regression')
plt.xlabel('BMI')
plt.ylabel('Disease Progression')
plt.title('Linear Regression: BMI vs Disease Progression')
plt.legend()
plt.show()

print("Model Parameters:")
print('Coefficients: \n', model.coef_)
print('Intercept: \n', model.intercept_)
Model Parameters:
Coefficients: 
 [998.57768914]
Intercept: 
 152.00335421448167
Plotting Grpah using Seaborn¶
In [22]:
plt.scatter(X_train, y_train, color='orange', label='Training Data')
plt.scatter(X_test, y_test, color='red', label='Test Data')
plt.plot(X_train, y_pred1, color='blue', label='Linear Regression Model')
plt.xlabel('BMI (Body Mass Index)')
plt.ylabel('Disease Progression')
plt.legend()
plt.show()
print("Model Parameters:")
print('Coefficients: \n', model.coef_)
print('Intercept: \n', model.intercept_)
Model Parameters:
Coefficients: 
 [998.57768914]
Intercept: 
 152.00335421448167
Plotting Grpah using Plotly¶
In [24]:
fig = go.Figure()
fig.add_trace(go.Scatter(x=X_train, y=y_train, mode='markers', name='Training Data'))
fig.add_trace(go.Scatter(x=X_test, y=y_test, mode='markers', name='Test Data'))
fig.add_trace(go.Scatter(x=X_train, y=y_pred1, mode='lines', name='Linear Regression Model'))
fig.update_layout(title='Linear Regression - BMI vs Disease Progression',
                  xaxis_title='BMI (Body Mass Index)',
                  yaxis_title='Disease Progression')
print("Model Parameters:")
print('Coefficients: \n', model.coef_)
print('Intercept: \n', model.intercept_)
fig.show()
Model Parameters:
Coefficients: 
 [998.57768914]
Intercept: 
 152.00335421448167
From the graphs, it can be observed as :¶
  1. The model may be overfitting the training data which indicates as a bad fit.
  2. There is a positive relationship between BMI and disease progression.